package org.hepeng.workx.spring.security.web;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.joor.Reflect;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.DefaultSecurityFilterChain;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.debug.DebugFilter;
import org.springframework.security.web.firewall.FirewalledRequest;
import org.springframework.security.web.firewall.HttpFirewall;
import org.springframework.security.web.util.UrlUtils;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * A {@link FilterChainProxy} whose set of security-ignored URLs can change at runtime.
 *
 * <p>On every pass through this proxy the configured {@link WebSecurityIgnoringUrlLoader} is
 * asked for additional Ant-style URL patterns. Each pattern becomes a filter-less
 * {@link DefaultSecurityFilterChain} placed in front of the statically configured chains. A
 * request matching one of them reaches the servlet {@link FilterChain} without running any
 * security filter. The statically configured chains are used unchanged and in their original
 * order, so chains that carry filters keep protecting their URLs.
 *
 * <p>This class re-implements the private request handling of {@link FilterChainProxy}. It
 * therefore reads some private members of the superclass reflectively: {@code FILTER_APPLIED},
 * {@code filterChains}, {@code firewall} and the nested {@code VirtualFilterChain} class.
 * NOTE(review): these are Spring Security implementation details; re-verify them when
 * upgrading Spring Security.
 *
 * @author he peng
 */

@Slf4j
public class DynamicSecurityIgnoringFilterChainProxy extends FilterChainProxy {

    /** Binary name of the private filter-chain implementation nested in {@link FilterChainProxy}. */
    private static final String VIRTUAL_FILTER_CHAIN_CLASS =
            "org.springframework.security.web.FilterChainProxy$VirtualFilterChain";

    /**
     * Request attribute the superclass uses to detect re-entrant invocations. It is read once
     * from the private superclass constant instead of on every request.
     */
    private final String filterApplied = Reflect.on(this).get("FILTER_APPLIED");

    /** Source of the dynamically ignored URL patterns; never {@code null}. */
    private final WebSecurityIgnoringUrlLoader ignoringUrlLoader;

    /**
     * Lazily resolved, accessible constructor of {@code FilterChainProxy$VirtualFilterChain}.
     * Concurrent first requests may resolve it more than once; every resolution yields an
     * equivalent constructor, so the race is harmless.
     */
    private volatile Constructor<?> virtualFilterChainConstructor;

    public DynamicSecurityIgnoringFilterChainProxy() {
        this.ignoringUrlLoader = new EmptyWebSecurityIgnoringUrlLoader();
    }

    public DynamicSecurityIgnoringFilterChainProxy(SecurityFilterChain chain) {
        this(chain, new EmptyWebSecurityIgnoringUrlLoader());
    }

    public DynamicSecurityIgnoringFilterChainProxy(List<SecurityFilterChain> filterChains) {
        this(filterChains, new EmptyWebSecurityIgnoringUrlLoader());
    }

    /**
     * @param chain             the single statically configured chain
     * @param ignoringUrlLoader source of dynamically ignored URLs; {@code null} means "none"
     */
    public DynamicSecurityIgnoringFilterChainProxy(SecurityFilterChain chain, WebSecurityIgnoringUrlLoader ignoringUrlLoader) {
        super(chain);
        this.ignoringUrlLoader = loaderOrEmpty(ignoringUrlLoader);
    }

    /**
     * @param filterChains      the statically configured chains, in match order
     * @param ignoringUrlLoader source of dynamically ignored URLs; {@code null} means "none"
     */
    public DynamicSecurityIgnoringFilterChainProxy(List<SecurityFilterChain> filterChains, WebSecurityIgnoringUrlLoader ignoringUrlLoader) {
        super(filterChains);
        this.ignoringUrlLoader = loaderOrEmpty(ignoringUrlLoader);
    }

    /** Returns {@code loader}, or an {@link EmptyWebSecurityIgnoringUrlLoader} when it is {@code null}. */
    private static WebSecurityIgnoringUrlLoader loaderOrEmpty(WebSecurityIgnoringUrlLoader loader) {
        return loader != null ? loader : new EmptyWebSecurityIgnoringUrlLoader();
    }

    /**
     * Mirrors {@link FilterChainProxy#doFilter}: on the outermost pass it marks the request and
     * clears the {@link SecurityContextHolder} afterwards. Re-entrant passes (the marker
     * attribute is already present) are delegated without clearing.
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        boolean clearContext = request.getAttribute(this.filterApplied) == null;
        if (!clearContext) {
            doFilterInternal(request, response, chain);
            return;
        }
        try {
            request.setAttribute(this.filterApplied, Boolean.TRUE);
            doFilterInternal(request, response, chain);
        } finally {
            SecurityContextHolder.clearContext();
            request.removeAttribute(this.filterApplied);
        }
    }

    /**
     * Validates the request through the configured firewall and resolves the filters for it.
     * It then either passes the request straight on (no or empty filter list) or runs it
     * through a {@code VirtualFilterChain}. Exceptions thrown by the filters or by the
     * downstream chain propagate to the caller.
     */
    private void doFilterInternal(ServletRequest request, ServletResponse response,
                                  FilterChain chain) throws IOException, ServletException {

        HttpServletRequest httpRequest = (HttpServletRequest) request;
        // Read on every request because setFirewall(...) may replace it at any time.
        HttpFirewall firewall = Reflect.on(this).get("firewall");
        FirewalledRequest fwRequest = firewall.getFirewalledRequest(httpRequest);
        HttpServletResponse fwResponse = firewall
                .getFirewalledResponse((HttpServletResponse) response);

        // Built per pass so the loader can change the ignored URLs at runtime.
        List<SecurityFilterChain> filterChains = resolveFilterChains(httpRequest);
        List<Filter> filters = findFilters(filterChains, fwRequest);

        if (filters == null || filters.isEmpty()) {
            if (log.isDebugEnabled()) {
                log.debug(UrlUtils.buildRequestUrl(fwRequest)
                        + (filters == null ? " has no matching filters"
                        : " has an empty filter list"));
            }

            fwRequest.reset();

            chain.doFilter(fwRequest, fwResponse);

            return;
        }

        FilterChain virtualChain = newVirtualFilterChain(fwRequest, chain, filters);
        virtualChain.doFilter(fwRequest, fwResponse);
    }

    /**
     * Instantiates the superclass's private {@code VirtualFilterChain}.
     *
     * @throws ServletException if the class or its constructor cannot be found or invoked,
     *                          e.g. after an incompatible Spring Security upgrade
     */
    private FilterChain newVirtualFilterChain(FirewalledRequest fwRequest, FilterChain chain,
                                              List<Filter> filters) throws ServletException {
        try {
            return (FilterChain) virtualFilterChainConstructor().newInstance(fwRequest, chain, filters);
        } catch (ReflectiveOperationException | RuntimeException e) {
            throw new ServletException("Unable to create " + VIRTUAL_FILTER_CHAIN_CLASS, e);
        }
    }

    /** Resolves, makes accessible, and caches the {@code VirtualFilterChain} constructor. */
    private Constructor<?> virtualFilterChainConstructor() throws ReflectiveOperationException {
        Constructor<?> constructor = this.virtualFilterChainConstructor;
        if (constructor == null) {
            Class<?> type = Class.forName(VIRTUAL_FILTER_CHAIN_CLASS, false,
                    FilterChainProxy.class.getClassLoader());
            constructor = type.getDeclaredConstructor(FirewalledRequest.class, FilterChain.class, List.class);
            constructor.setAccessible(true);
            this.virtualFilterChainConstructor = constructor;
        }
        return constructor;
    }

    /**
     * Returns the filters of the first chain matching {@code request}, or {@code null} when no
     * chain matches. {@code null} and an empty list are distinguished only for debug logging.
     */
    private static List<Filter> findFilters(List<SecurityFilterChain> filterChains, HttpServletRequest request) {
        for (SecurityFilterChain chain : filterChains) {
            if (chain.matches(request)) {
                return chain.getFilters();
            }
        }
        return null;
    }

    /**
     * Builds the chain list for one pass.
     *
     * <p>The result holds one filter-less Ant-path chain for each distinct, non-blank URL
     * returned by the loader, followed by the statically configured chains, unchanged and in
     * their original order. Existing chains are never rewritten. This preserves their filters,
     * HTTP-method restrictions and matcher types.
     *
     * @param request the incoming request, passed to the loader
     * @return the chains to match against, in priority order; never {@code null}
     */
    private List<SecurityFilterChain> resolveFilterChains(HttpServletRequest request) {
        List<SecurityFilterChain> configuredChains = Reflect.on(this).get("filterChains");
        if (configuredChains == null) {
            // The no-arg constructor leaves the superclass's chain list unset.
            configuredChains = Collections.emptyList();
        }

        List<String> loadedUrls = this.ignoringUrlLoader.loadIgnoringUrl(request);
        if (CollectionUtils.isEmpty(loadedUrls)) {
            return configuredChains;
        }

        // AntPathRequestMatcher rejects blank patterns, so drop them up front.
        List<String> ignoringUrls = loadedUrls.stream()
                .filter(url -> url != null && !url.trim().isEmpty())
                .distinct()
                .collect(Collectors.toList());

        List<SecurityFilterChain> chains = new ArrayList<>(ignoringUrls.size() + configuredChains.size());
        for (String ignoringUrl : ignoringUrls) {
            chains.add(new DefaultSecurityFilterChain(new AntPathRequestMatcher(ignoringUrl)));
        }
        chains.addAll(configuredChains);
        return chains;
    }

    /**
     * Supplies Ant-style URL patterns that should bypass all security filters. It is invoked
     * on every pass through {@link DynamicSecurityIgnoringFilterChainProxy}. A {@code null} or
     * empty result means "no extra ignored URLs"; {@code null} and blank entries are skipped.
     */
    public interface WebSecurityIgnoringUrlLoader {
        List<String> loadIgnoringUrl(HttpServletRequest request);
    }

    /** A {@link WebSecurityIgnoringUrlLoader} that never ignores any additional URL. */
    public static class EmptyWebSecurityIgnoringUrlLoader implements DynamicSecurityIgnoringFilterChainProxy.WebSecurityIgnoringUrlLoader {
        @Override
        public List<String> loadIgnoringUrl(HttpServletRequest request) {
            return Collections.emptyList();
        }
    }

    /**
     * Replaces the application's {@link FilterChainProxy} with a
     * {@link DynamicSecurityIgnoringFilterChainProxy} built from the same chains. This covers
     * both a plain {@link FilterChainProxy} bean and one wrapped inside a {@link DebugFilter}.
     * The original proxy's {@link HttpFirewall} is carried over. Proxies that are already
     * dynamic are left untouched.
     */
    public static class FilterChainProxyPostProcessor implements BeanPostProcessor {

        /** Optional application-provided loader; the empty loader is used when absent. */
        @Autowired(required = false)
        private DynamicSecurityIgnoringFilterChainProxy.WebSecurityIgnoringUrlLoader ignoringUrlLoader;

        @Override
        public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
            return bean;
        }

        @Override
        public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
            if (bean instanceof DebugFilter) {
                FilterChainProxy fcp = Reflect.on(bean).get("fcp");
                if (!(fcp instanceof DynamicSecurityIgnoringFilterChainProxy)) {
                    Reflect.on(bean).set("fcp", toDynamicProxy(fcp));
                }
                return bean;
            }
            if (bean instanceof FilterChainProxy && !(bean instanceof DynamicSecurityIgnoringFilterChainProxy)) {
                return toDynamicProxy((FilterChainProxy) bean);
            }
            return bean;
        }

        /** Builds a dynamic proxy with the same chains and firewall as {@code original}. */
        private DynamicSecurityIgnoringFilterChainProxy toDynamicProxy(FilterChainProxy original) {
            DynamicSecurityIgnoringFilterChainProxy proxy = new DynamicSecurityIgnoringFilterChainProxy(
                    original.getFilterChains(), this.ignoringUrlLoader);
            // Without this a custom firewall would silently revert to the default one.
            HttpFirewall firewall = Reflect.on(original).get("firewall");
            if (firewall != null) {
                proxy.setFirewall(firewall);
            }
            return proxy;
        }
    }
}
